# Copyright 2024 The Langfun Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base classes for Langfun evaluation."""

import dataclasses
import inspect
from typing import Any, Callable, Iterator
import langfun.core as lf
import pyglove as pg


@dataclasses.dataclass
class Example(pg.JSONConvertible, pg.views.HtmlTreeView.Extension):
  """An example for evaluation.

  An evaluation example contains the input and output of an evaluation task,
  as well as metadata about the evaluation process, such as execution time,
  LLM usage, and metric results.

  Attributes:
    id: The 1-based ID of the example in the evaluation set.
    input: An element returned from the `Evaluable.inputs` functor, which serves
      as the input for `lf.Evaluable.process`.
    output: The output of `lf.Evaluable.process` method. If `pg.MISSING_VALUE`,
      it indicates the example has not been processed yet.
    error: The error raised from `lf.Evaluable.process`. If None, it
      indicates the process was successful.
    metadata: The metadata of the example produced by `lf.Evaluable.process`.
    metric_metadata: The dictionary returned from `Metric.audit`, which contains
      metadata about metric computation for this example.
    newly_processed: Whether this example is processed in the current run. If
      False, it indicates the example was loaded from a checkpoint from previous
      runs.
    start_time: The start time of processing this example.
    end_time: The end time of processing this example.
    usage_summary: The summary of LLM usages for processing this example.
    execution_status: The timeit status of processing this example.
  """
  id: int
  input: Any = pg.MISSING_VALUE
  output: Any = pg.MISSING_VALUE
  error: pg.ErrorInfo | None = None
  metadata: dict[str, Any] = dataclasses.field(default_factory=dict)
  metric_metadata: dict[str, Any] | None = None
  # Execution information.
  newly_processed: bool = True
  start_time: float | None = None
  end_time: float | None = None
  usage_summary: lf.UsageSummary | None = None
  execution_status: dict[str, pg.utils.TimeIt.Status] | None = None

  @property
  def is_processed(self) -> bool:
    """Returns whether the item has been processed (i.e. has an output)."""
    return pg.MISSING_VALUE != self.output

  @property
  def has_error(self) -> bool:
    """Returns whether the item has an error."""
    return self.error is not None

  @property
  def elapse(self) -> float | None:
    """Returns the elapse time of the 'evaluate' step, or None if unknown."""
    if self.execution_status is not None:
      return self.execution_status['evaluate'].elapse
    return None

  def to_json(self, *, exclude_input: bool = False, **kwargs):
    """Returns the JSON representation of the item.

    Fields equal to their defaults are omitted (`exclude_default=True`), so
    e.g. an empty `metadata` dict is not written. `newly_processed` is not
    serialized.

    Args:
      exclude_input: If True, the input is not serialized.
      **kwargs: Keyword arguments passed through to `to_json_dict`.

    Returns:
      A JSON-compatible dict.
    """
    return self.to_json_dict(
        fields=dict(
            id=(self.id, None),
            input=(
                self.input if not exclude_input else pg.MISSING_VALUE,
                pg.MISSING_VALUE
            ),
            output=(self.output, pg.MISSING_VALUE),
            error=(self.error, None),
            metadata=(self.metadata, {}),
            metric_metadata=(self.metric_metadata, None),
            start_time=(self.start_time, None),
            end_time=(self.end_time, None),
            usage_summary=(self.usage_summary, None),
            execution_status=(self.execution_status, None),
        ),
        exclude_default=True,
        **kwargs,
    )

  @classmethod
  def from_json(
      cls,
      json_value: dict[str, Any],
      *,
      example_input_by_id: Callable[[int], Any] | None = None,
      load_example_metadata: bool | Callable[['Example'], bool] = False,
      **kwargs
  ) -> 'Example':
    """Creates an example from the JSON representation.

    Args:
      json_value: The JSON dict of the example (as produced by `to_json`). The
        dict passed in is not modified.
      example_input_by_id: Optional function mapping an example ID to its
        input. When provided, any serialized input in `json_value` is ignored
        and the returned object is used as the example input as-is.
      load_example_metadata: Whether to load the `metadata` field. Either a
        bool, or a predicate called with the example (loaded without metadata)
        that decides whether to load it.
      **kwargs: Keyword arguments passed through to `pg.from_json`.

    Returns:
      The deserialized example.
    """
    # Work on a shallow copy so the caller's dict is not mutated by the pops
    # below.
    json_value = dict(json_value)
    example_id = json_value.get('id')

    # The input is resolved separately from the other fields so that it is
    # deserialized at most once (and caller-provided inputs are not re-run
    # through `pg.from_json`).
    serialized_input = json_value.pop('input', pg.MISSING_VALUE)
    if example_input_by_id:
      example_input = example_input_by_id(example_id)
    elif serialized_input is not pg.MISSING_VALUE:
      example_input = pg.from_json(serialized_input, **kwargs)
    else:
      example_input = pg.MISSING_VALUE

    # NOTE(daiyip): We need to load the types of the examples into the
    # deserialization context, otherwise the deserialization will fail if the
    # types are not registered.
    def example_class_defs(example) -> list[type[Any]]:
      """Collects classes and `pg.Object` types referenced within `example`."""
      referred_types = set()
      def _visit(k, v, p):
        del k, p
        if inspect.isclass(v):
          referred_types.add(v)
        elif isinstance(v, pg.Object):
          referred_types.add(v.__class__)
        return pg.TraverseAction.ENTER
      pg.traverse(example, _visit)
      return list(referred_types)

    # We delay loading the metadata until the other parts of the example are
    # loaded. So we could apply the filter to decide whether to load the
    # metadata.
    metadata_dict = json_value.pop('metadata', None)
    with pg.JSONConvertible.load_types_for_deserialization(
        *example_class_defs(example_input)
    ):
      example = cls(
          input=example_input,
          **{k: pg.from_json(v, **kwargs) for k, v in json_value.items()},
      )
      if callable(load_example_metadata):
        load_example_metadata = load_example_metadata(example)
      # `to_json` omits empty metadata, so an absent entry means `{}`. Keep the
      # dataclass default in that case instead of overwriting it with None.
      if load_example_metadata and metadata_dict is not None:
        example.metadata = pg.from_json(metadata_dict, **kwargs)
      return example

  @classmethod
  def iter_ckpts(
      cls,
      ckpt_file: str | list[str],
      example_input_by_id: Callable[[int], Any] | None = None,
      load_example_metadata: bool | Callable[['Example'], bool] = True,
      convert_unknown: bool = True,
      **kwargs
  ) -> Iterator['Example']:
    """Iterates Examples from the checkpoint files.

    Args:
      ckpt_file: A checkpoint file path or a list of paths. Each file is read
        with `pg.io.open_sequence`, one JSON record per example.
      example_input_by_id: Forwarded to `Example.from_json`.
      load_example_metadata: Forwarded to `Example.from_json`; a bool or a
        predicate on the loaded example.
      convert_unknown: Forwarded to `pg.from_json_str`.
      **kwargs: Additional keyword arguments for `pg.from_json_str`.

    Yields:
      The examples stored in the checkpoint files, in file/record order.
    """
    ckpt_files = [ckpt_file] if isinstance(ckpt_file, str) else ckpt_file
    for path in ckpt_files:
      with pg.io.open_sequence(path) as f:
        for record in f:
          example = pg.from_json_str(
              record,
              example_input_by_id=example_input_by_id,
              load_example_metadata=load_example_metadata,
              convert_unknown=convert_unknown,
              **kwargs
          )
          assert isinstance(example, cls), example
          yield example

  #
  # HTML rendering.
  #

  def _html_tree_view_content(
      self,
      *,
      view: pg.views.HtmlTreeView,
      root_path: pg.KeyPath | None = None,
      extra_flags: dict[str, Any] | None = None,
      **kwargs
  ):
    """Renders the example as a header (navigation, badges) plus field tabs.

    `extra_flags['num_examples']`, when present, is used to hide the "next"
    link on the last example.
    """
    root_path = root_path or pg.KeyPath()
    extra_flags = extra_flags or {}
    num_examples = extra_flags.get('num_examples', None)

    def _metric_label_group(metric_metadata: dict[str, Any] | None):
      """Renders a label group for metric metadata."""
      badges = []
      if metric_metadata:
        for metric_name, metadata in metric_metadata.items():
          assert isinstance(metadata, dict), (metric_name, metadata)
          for k, v in metadata.items():
            # Boolean values get a `_true`/`_false` suffix so CSS can color
            # them (e.g. `is_correct_true`).
            css_class = k
            if isinstance(v, bool):
              css_class += '_true' if v else '_false'
            badge = pg.views.html.controls.Badge(
                f'{k}:{v}',
                tooltip=f'{metric_name}: {k}',
                css_classes=[css_class],
            )
            badges.append(badge)
      return pg.views.html.controls.LabelGroup(badges)

    def _render_header():
      """Renders navigation links, the example ID, usage and metric badges."""
      return pg.Html.element(
          'div',
          [
              pg.Html.element(
                  'div',
                  [
                      # Previous button.
                      pg.views.html.controls.Label(   # pylint: disable=g-long-ternary
                          '◀',
                          link=f'{self.id - 1}.html',
                          css_classes=['previous'],
                      ) if self.id > 1 else None,
                      # Current example ID.
                      pg.views.html.controls.Label(
                          f'#{self.id}',
                          css_classes=['example-id'],
                      ),
                      # Next button.
                      pg.views.html.controls.Label(   # pylint: disable=g-long-ternary
                          '▶',
                          link=f'{self.id + 1}.html',
                          css_classes=['next'],
                      ) if (num_examples is None
                            or self.id < num_examples) else None,

                  ]
              ),
              pg.Html.element(
                  'div',
                  [
                      # Usage summary.
                      pg.view(  # pylint: disable=g-long-ternary
                          self.usage_summary,
                          extra_flags=dict(as_badge=True)
                      ) if self.usage_summary is not None else None,
                      # Metric metadata.
                      _metric_label_group(self.metric_metadata)
                  ],
                  css_classes=['example-container'],
              )
          ]
      )

    def _render_content():
      """Renders one tab per non-default field; the last tab is selected."""
      def _tab(label, key, default):
        # Fields still at their default value are not shown.
        field = getattr(self, key)
        if default == field:
          return None
        return pg.views.html.controls.Tab(
            label=label,
            content=view.render(
                field,
                root_path=root_path + key,
                collapse_level=None,
                **view.get_passthrough_kwargs(**kwargs),
            ),
        )
      tabs = [
          _tab('Input', 'input', pg.MISSING_VALUE),
          _tab('Output', 'output', pg.MISSING_VALUE),
          _tab('Output Metadata', 'metadata', {}),
          _tab('Error', 'error', None),
      ]
      tabs = [tab for tab in tabs if tab is not None]
      return pg.views.html.controls.TabControl(
          tabs,
          len(tabs) - 1,
      )

    return pg.Html.element(
        'div',
        [
            _render_header(),
            _render_content(),
        ],
        css_classes=['eval-example']
    )

  def _html_tree_view_summary(self, *, view, **kwargs):
    """Returns None: examples are rendered without a summary line."""
    return None

  @classmethod
  def _html_tree_view_css_styles(cls) -> list[str]:
    """Returns the CSS styles used by the example HTML view."""
    return super()._html_tree_view_css_styles() + [
        """
        .example-container {
          display: block;
          padding: 10px;
        }
        .example-id {
          font-weight: bold;
          font-size: 40px;
          margin: 0 10px;
          vertical-align: middle;
        }
        a.previous, a.next {
          text-decoration: none;
          vertical-align: middle;
          display: inline-block;
          padding: 8px 8px;
          color: #DDD;
        }
        a.previous:hover, a.next:hover {
          background-color: #ddd;
          color: black;
        }
        /* Badge styles. */
        .eval-example .badge.is_correct_true {
          color: green;
          background-color: #dcefbe;
        }
        .eval-example .badge.is_correct_false {
          color: orange;
          background-color: #ffefc4;
        }
        .eval-example .badge.error {
          color: red;
          background-color: #fdcccc;
        }
        .eval-example .badge.score {
          color: blue;
          background-color: #c4dced;
        }
        """
    ]

